{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a Model on Encrypted Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# python 3.7 is required\n",
    "assert sys.version_info[0] == 3 and sys.version_info[1] == 7, \"python 3.7 is required\"\n",
    "\n",
    "import crypten\n",
    "crypten.init()\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training split: the samples the models below are fitted on.\n",
    "digits = torchvision.datasets.MNIST(root='/tmp/data',\n",
    "                                    train=True,\n",
    "                                    transform=torchvision.transforms.ToTensor(),\n",
    "                                    download=True)\n",
    "\n",
    "# Held-out split: train=False selects the MNIST test images, so evaluation\n",
    "# samples do not overlap with the training samples.\n",
    "digits_test = torchvision.datasets.MNIST(root='/tmp/data',\n",
    "                                         train=False,\n",
    "                                         transform=torchvision.transforms.ToTensor(),\n",
    "                                         download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x13193bb90>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAN9klEQVR4nO3df4xV9ZnH8c+zWP6QojBrOhKKSyEGg8ZON4gbl6w1hvojGhw1TSexoZE4/YNJaLIhNewf1WwwZBU2SzTNTKMWNl1qEzUgaQouoOzGhDgiKo5LdQ2mTEaowZEf/mCHefaPezBTnfu9w7nn3nOZ5/1Kbu6957nnnicnfDi/7pmvubsATH5/VXYDAJqDsANBEHYgCMIOBEHYgSAuaubCzIxT/0CDubuNN72uLbuZ3Wpmh8zsPTN7sJ7vAtBYlvc6u5lNkfRHSUslHZH0qqQudx9IzMOWHWiwRmzZF0t6z93fd/czkn4raVkd3weggeoJ+2xJfxrz/kg27S+YWbeZ9ZtZfx3LAlCnhp+gc/c+SX0Su/FAmerZsg9KmjPm/bezaQBaUD1hf1XSlWb2HTObKulHkrYV0xaAouXejXf3ETPrkbRD0hRJT7n724V1BqBQuS+95VoYx+xAwzXkRzUALhyEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBJF7yGZcGKZMmZKsX3rppQ1dfk9PT9XaxRdfnJx3wYIFyfrKlSuT9ccee6xqraurKznv559/nqyvW7cuWX/44YeT9TLUFXYzOyzppKSzkkbcfVERTQEoXhFb9pvc/aMCvgdAA3HMDgRRb9hd0k4ze83Musf7gJl1m1m/mfXXuSwAdah3N36Juw+a2bckvWhm/+Pue8d+wN37JPVJkpl5ncsDkFNdW3Z3H8yej0l6XtLiIpoCULzcYTezaWY2/dxrST+QdLCoxgAUq57d+HZJz5vZue/5D3f/QyFdTTJXXHFFsj516tRk/YYbbkjWlyxZUrU2Y8aM5Lz33HNPsl6mI0eOJOsbN25M1js7O6vWTp48mZz3jTfeSNZffvnlZL0V5Q67u78v6bsF9gKggbj0BgRB2IEgCDsQBGEHgiDsQBDm3rwftU3WX9B1dHQk67t3707WG32baasaHR1N1u+///5k/dSpU7mXPTQ0lKx//PHHyfqhQ4dyL7vR3N3Gm86WHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeC4Dp7Adra2pL1ffv2Jevz5s0rsp1C1ep9eHg4Wb/pppuq1s6cOZOcN+rvD+rFdXYgOMIOBEHYgSAIOxAEYQeCIOxAEIQdCIIhmwtw/PjxZH316tXJ+h133JGsv/7668l6rT+pnHLgwIFkfenSpcn66dOnk/Wrr766am3VqlXJeVEstuxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EAT3s7eASy65JFmvNbxwb29v1dqKFSuS8953333J+pYtW5J1tJ7c97Ob2VNmdszMDo6Z1mZmL5rZu9nzzCKbBVC8iezG/1rSrV+Z9qCkXe5+paRd2XsALaxm2N19r6Sv/h50maRN2etNku4quC8ABcv72/h2dz83WNaHktqrfdDMuiV151wOgILUfSOMu3vqxJu790nqkzhBB5Qp76W3o2Y2S5Ky52PFtQSgEfKGfZuk5dnr5ZK2FtMOgEapuRtvZlskfV/SZWZ2RNIvJK2T9DszWyHpA0k/bGSTk92JEyfqmv+TTz7JPe8DDzyQrD/zzDPJeq0x1tE6aobd3buqlG4uuBcADcTPZYEgCDsQBGEHgiDsQBCEHQiCW1wngWnTplWtvfDCC8l5b7zxxmT9tttuS9Z37tyZrKP5GLIZCI6wA0EQ
diAIwg4EQdiBIAg7EARhB4LgOvskN3/+/GR9//79yfrw8HCyvmfPnmS9v7+/au2JJ55IztvMf5uTCdfZgeAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIrrMH19nZmaw//fTTyfr06dNzL3vNmjXJ+ubNm5P1oaGhZD0qrrMDwRF2IAjCDgRB2IEgCDsQBGEHgiDsQBBcZ0fSNddck6xv2LAhWb/55vyD/fb29ibra9euTdYHBwdzL/tClvs6u5k9ZWbHzOzgmGkPmdmgmR3IHrcX2SyA4k1kN/7Xkm4dZ/q/untH9vh9sW0BKFrNsLv7XknHm9ALgAaq5wRdj5m9me3mz6z2ITPrNrN+M6v+x8gANFzesP9S0nxJHZKGJK2v9kF373P3Re6+KOeyABQgV9jd/ai7n3X3UUm/krS42LYAFC1X2M1s1pi3nZIOVvssgNZQ8zq7mW2R9H1Jl0k6KukX2fsOSS7psKSfunvNm4u5zj75zJgxI1m/8847q9Zq3StvNu7l4i/t3r07WV+6dGmyPllVu85+0QRm7Bpn8pN1dwSgqfi5LBAEYQeCIOxAEIQdCIKwA0FwiytK88UXXyTrF12Uvlg0MjKSrN9yyy1Vay+99FJy3gsZf0oaCI6wA0EQdiAIwg4EQdiBIAg7EARhB4KoedcbYrv22muT9XvvvTdZv+6666rWal1Hr2VgYCBZ37t3b13fP9mwZQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBILjOPsktWLAgWe/p6UnW77777mT98ssvP++eJurs2bPJ+tBQ+q+Xj46OFtnOBY8tOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EwXX2C0Cta9ldXeMNtFtR6zr63Llz87RUiP7+/mR97dq1yfq2bduKbGfSq7llN7M5ZrbHzAbM7G0zW5VNbzOzF83s3ex5ZuPbBZDXRHbjRyT9o7svlPR3klaa2UJJD0ra5e5XStqVvQfQomqG3d2H3H1/9vqkpHckzZa0TNKm7GObJN3VqCYB1O+8jtnNbK6k70naJ6nd3c/9OPlDSe1V5umW1J2/RQBFmPDZeDP7pqRnJf3M3U+MrXlldMhxB2109z53X+Tui+rqFEBdJhR2M/uGKkH/jbs/l00+amazsvosScca0yKAItTcjTczk/SkpHfcfcOY0jZJyyWty563NqTDSaC9fdwjnC8tXLgwWX/88ceT9auuuuq8eyrKvn37kvVHH320am3r1vQ/GW5RLdZEjtn/XtKPJb1lZgeyaWtUCfnvzGyFpA8k/bAxLQIoQs2wu/t/Sxp3cHdJNxfbDoBG4eeyQBCEHQiCsANBEHYgCMIOBMEtrhPU1tZWtdbb25uct6OjI1mfN29erp6K8MorryTr69evT9Z37NiRrH/22Wfn3RMagy07EARhB4Ig7EAQhB0IgrADQRB2IAjCDgQR5jr79ddfn6yvXr06WV+8eHHV2uzZs3P1VJRPP/20am3jxo3JeR955JFk/fTp07l6Quthyw4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQYS5zt7Z2VlXvR4DAwPJ+vbt25P1kZGRZD11z/nw8HByXsTBlh0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgjB3T3/AbI6kzZLaJbmkPnf/NzN7SNIDkv6cfXSNu/++xnelFwagbu4+7qjLEwn7LEmz3H2/mU2X9Jqku1QZj/2Uuz820SYIO9B41cI+kfHZhyQNZa9Pmtk7ksr90ywAztt5HbOb2VxJ35O0L5vUY2ZvmtlTZjazyjzdZtZvZv11dQqgLjV347/8oNk3Jb0saa27P2dm7ZI+UuU4/p9V2dW/v8Z3sBsPNFjuY3ZJMrNvSNouaYe7bxinPlfSdne/psb3EHagwaqFveZuvJmZpCclvTM26NmJu3M6JR2st0kAjTORs/FLJP2XpLckjWaT10jqktShym78YUk/zU7mpb6LLTvQYHXtxheFsAONl3s3HsDkQNiBIAg7
EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQii2UM2fyTpgzHvL8umtaJW7a1V+5LoLa8ie/ubaoWm3s/+tYWb9bv7otIaSGjV3lq1L4ne8mpWb+zGA0EQdiCIssPeV/LyU1q1t1btS6K3vJrSW6nH7ACap+wtO4AmIexAEKWE3cxuNbNDZvaemT1YRg/VmNlhM3vLzA6UPT5dNobeMTM7OGZam5m9aGbvZs/jjrFXUm8Pmdlgtu4OmNntJfU2x8z2mNmAmb1tZquy6aWuu0RfTVlvTT9mN7Mpkv4oaamkI5JeldTl7gNNbaQKMzssaZG7l/4DDDP7B0mnJG0+N7SWmf2LpOPuvi77j3Kmu/+8RXp7SOc5jHeDeqs2zPhPVOK6K3L48zzK2LIvlvSeu7/v7mck/VbSshL6aHnuvlfS8a9MXiZpU/Z6kyr/WJquSm8twd2H3H1/9vqkpHPDjJe67hJ9NUUZYZ8t6U9j3h9Ra4337pJ2mtlrZtZddjPjaB8zzNaHktrLbGYcNYfxbqavDDPeMusuz/Dn9eIE3dctcfe/lXSbpJXZ7mpL8soxWCtdO/2lpPmqjAE4JGl9mc1kw4w/K+ln7n5ibK3MdTdOX01Zb2WEfVDSnDHvv51NawnuPpg9H5P0vCqHHa3k6LkRdLPnYyX38yV3P+ruZ919VNKvVOK6y4YZf1bSb9z9uWxy6etuvL6atd7KCPurkq40s++Y2VRJP5K0rYQ+vsbMpmUnTmRm0yT9QK03FPU2Scuz18slbS2xl7/QKsN4VxtmXCWvu9KHP3f3pj8k3a7KGfn/lfRPZfRQpa95kt7IHm+X3ZukLars1v2fKuc2Vkj6a0m7JL0r6T8ltbVQb/+uytDeb6oSrFkl9bZElV30NyUdyB63l73uEn01Zb3xc1kgCE7QAUEQdiAIwg4EQdiBIAg7EARhB4Ig7EAQ/w8ie3GmjcGk5QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(digits[0][0][0], cmap='gray', interpolation='none')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label for image is  5\n"
     ]
    }
   ],
   "source": [
    "print(\"label for image is \", digits[0][1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess Into Tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def take_samples(digits, n_samples=1000, num_classes=10):\n",
    "    \"\"\"Collect the first ``n_samples`` items of a dataset as stacked tensors.\n",
    "\n",
    "    Args:\n",
    "        digits: iterable of ``(image, label)`` pairs, where ``image`` is a\n",
    "            tensor and ``label`` is an integer class index.\n",
    "        n_samples: maximum number of items to take; ``None`` takes them all.\n",
    "        num_classes: width of the one-hot label encoding.\n",
    "\n",
    "    Returns:\n",
    "        ``(images, labels)``: the images concatenated along dim 0 and the\n",
    "        labels one-hot encoded with shape ``(n, num_classes)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``n_samples`` is negative, or if there is nothing to\n",
    "            take (empty dataset or ``n_samples == 0``).\n",
    "    \"\"\"\n",
    "    from itertools import islice\n",
    "\n",
    "    if n_samples is not None and n_samples < 0:\n",
    "        raise ValueError(f\"n_samples must be non-negative, got {n_samples}\")\n",
    "\n",
    "    # islice stops before fetching item number n_samples, so no sample is\n",
    "    # loaded (and transformed) only to be thrown away.\n",
    "    samples = list(islice(digits, n_samples))\n",
    "    if not samples:\n",
    "        raise ValueError(\"no samples to take: dataset is empty or n_samples is 0\")\n",
    "\n",
    "    # torch.cat joins along dim 0, so (1, H, W) images become (n, H, W).\n",
    "    images = torch.cat([image for image, _ in samples])\n",
    "    class_ids = torch.tensor([int(label) for _, label in samples])\n",
    "    labels = torch.nn.functional.one_hot(class_ids, num_classes)\n",
    "    return images, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "images, labels = take_samples(digits, n_samples=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([100, 28, 28])\n",
      "torch.Size([100, 10])\n"
     ]
    }
   ],
   "source": [
    "print(images.shape)\n",
    "print(labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "images_enc = crypten.cryptensor(images)\n",
    "labels_enc = crypten.cryptensor(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPCTensor(\n",
       "\t_tensor=tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,   771,  4626,  4626,  4626, 32382, 34952, 44975,  6682,\n",
       "         42662, 65536, 63479, 32639,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,  7710,  9252,\n",
       "         24158, 39578, 43690, 65021, 65021, 65021, 65021, 65021, 57825, 44204,\n",
       "         65021, 62194, 50115, 16448,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0, 12593, 61166, 65021,\n",
       "         65021, 65021, 65021, 65021, 65021, 65021, 65021, 64507, 23901, 21074,\n",
       "         21074, 14392, 10023,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,  4626, 56283, 65021,\n",
       "         65021, 65021, 65021, 65021, 50886, 46774, 63479, 61937,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0, 20560, 40092,\n",
       "         27499, 65021, 65021, 52685,  2827,     0, 11051, 39578,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,  3598,\n",
       "           257, 39578, 65021, 23130,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0, 35723, 65021, 48830,   514,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,  2827, 48830, 65021, 17990,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,  8995, 61937, 57825, 41120, 27756,   257,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0, 20817, 61680, 65021, 65021, 30583,  6425,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0, 11565, 47802, 65021, 65021, 38550,  6939,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,  4112, 23901, 64764, 65021, 48059,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0, 63993, 65021, 63993,\n",
       "         16448,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0, 11822, 33410, 47031, 65021, 65021, 53199,\n",
       "           514,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0, 10023, 38036, 58853, 65021, 65021, 65021, 64250, 46774,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "          6168, 29298, 56797, 65021, 65021, 65021, 65021, 51657, 20046,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,  5911, 16962,\n",
       "         54741, 65021, 65021, 65021, 65021, 50886, 20817,   514,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,  4626, 43947, 56283, 65021,\n",
       "         65021, 65021, 65021, 50115, 20560,  2313,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0, 14135, 44204, 58082, 65021, 65021, 65021,\n",
       "         65021, 62708, 34181,  2827,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0, 34952, 65021, 65021, 65021, 54484, 34695,\n",
       "         33924,  4112,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0],\n",
       "        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0]])\n",
       "\tplain_text=HIDDEN\n",
       "\tptype=ptype.arithmetic\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images_enc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test set\n",
    "images_test, labels_test = take_samples(digits_test, n_samples=20)\n",
    "images_test_enc = crypten.cryptensor(images_test)\n",
    "labels_test_enc = crypten.cryptensor(labels_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logistic Regression Model\n",
    "\n",
    "(multiclass logistic regression)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(crypten.nn.Module):\n",
    "    \"\"\"Multiclass logistic regression: a single linear layer with 10 outputs.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # images are 28x28 pixels; one output score per class\n",
    "        self.linear = crypten.nn.Linear(28 * 28, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a batch of images to class scores.\n",
    "\n",
    "        The linear layer consumes 28 * 28 features per sample, so everything\n",
    "        after the batch dimension is flattened explicitly instead of relying\n",
    "        on the layer to accept (N, 28, 28) input.\n",
    "        \"\"\"\n",
    "        x = x.flatten(start_dim=1)\n",
    "        return self.linear(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LogisticRegression().encrypt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<crypten.autograd_cryptensor.AutogradCrypTensor object at 0x12908bbd0>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(images_enc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train Model on Encrypted Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, X, y, epochs=10, learning_rate=0.05):\n",
    "    \"\"\"Train ``model`` on the whole of ``(X, y)`` for ``epochs`` passes.\n",
    "\n",
    "    Every pass clears the gradients, computes the cross-entropy loss of the\n",
    "    model output against ``y``, prints the loss in plain text, back-propagates\n",
    "    and calls ``model.update_parameters(learning_rate)``.\n",
    "\n",
    "    Returns:\n",
    "        The same ``model`` object, after training.\n",
    "    \"\"\"\n",
    "    loss_fn = crypten.nn.CrossEntropyLoss()\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        model.zero_grad()\n",
    "        loss = loss_fn(model(X), y)\n",
    "        # get_plain_text() reveals the loss value so progress can be printed\n",
    "        print(f\"epoch {epoch} loss: {loss.get_plain_text()}\")\n",
    "        loss.backward()\n",
    "        model.update_parameters(learning_rate)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0 loss: 2.332916259765625\n",
      "epoch 1 loss: 2.2252349853515625\n",
      "epoch 2 loss: 2.1306304931640625\n",
      "epoch 3 loss: 2.0459747314453125\n",
      "epoch 4 loss: 1.968963623046875\n",
      "epoch 5 loss: 1.89739990234375\n",
      "epoch 6 loss: 1.8313140869140625\n",
      "epoch 7 loss: 1.7688751220703125\n",
      "epoch 8 loss: 1.7100067138671875\n",
      "epoch 9 loss: 1.6540679931640625\n"
     ]
    }
   ],
   "source": [
    "model = train_model(model, images_enc, labels_enc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decrypt Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction = model(images_enc[3].unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prediction.get_plain_text().argmax()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x131b099d0>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAMb0lEQVR4nO3db4gc9R3H8c8ntkWIotHQM9rUtMUnUmwsQQo9SoppiCIkfRKaByXS0vNBlQoVIlaoUgqhVouIFq5o/pTWUog2oZS2NkRtCZacksaoidqQYI54VxGpeZTqfftgJ3LG29lzZ2Znk+/7Bcfuznd35suQT+bf7vwcEQJw7lvQdgMABoOwA0kQdiAJwg4kQdiBJD4xyIXZ5tQ/0LCI8FzTK23Zba+xfdj267bvrDIvAM1yv9fZbZ8n6VVJ35B0XNI+SRsi4uWSz7BlBxrWxJb9OkmvR8SRiDgl6XeS1laYH4AGVQn7FZLemPX6eDHtQ2yP2Z6wPVFhWQAqavwEXUSMSxqX2I0H2lRlyz4paems158ppgEYQlXCvk/SVbY/Z/tTkr4laVc9bQGoW9+78RHxnu1bJf1F0nmSHouIl2rrDECt+r701tfCOGYHGtfIl2oAnD0IO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLvIZuBpt19992l9Xvvvbe0vmBB923ZypUrSz/7zDPPlNbPRpXCbvuopHclvS/pvYhYUUdTAOpXx5b96xHxVg3zAdAgjtmBJKqGPST91fbztsfmeoPtMdsTticqLgtABVV340cjYtL2pyU9ZftQRDw7+w0RMS5pXJJsR8XlAehTpS17REwWj9OSnpR0XR1NAahf32G3vdD2haefS1ot6WBdjQGoV5Xd+BFJT9o+PZ/fRsSfa+kKKdx8882l9U2bNpXWZ2Zm+l52RL4jyr7DHhFHJH2pxl4ANIhLb0AShB1IgrADSRB2IAnCDiTBT1zRmiuvvLK0fv755w+okxzYsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAElxnR6NWrVrVtXbbbbdVmvehQ4dK6zfddFPX2tTUVKVln43YsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAElxnRyWjo6Ol9S1btnStXXTRRZWWfd9995XWjx07Vmn+5xq27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBNfZUcnGjRtL65dffnnf83766adL69u3b+973hn13LLbfsz2tO2Ds6ZdYvsp268Vj4uabRNAVfPZjd8qac0Z0+6UtDsirpK0u3gNYIj1DHtEPCvp7TMmr5W0rXi+TdK6mvsCULN+j9lHIuJE8fxNSSPd3mh7TNJYn8sBUJPKJ+giImxHSX1c0rgklb0PQLP6vfQ2ZXuJJBWP0/W1BKAJ/YZ9l6TT11w2StpZTzsAmuKI8j1r249LWilpsaQpST+W9AdJv5f0WUnHJK2PiDNP4s01L3bjzzKLFy8urfe6//rMzEzX2jvvvFP62fXr15fW9+zZU1rPKiI81/Sex+wRsaFL6fpKHQEYKL4uCyRB2IEkCDuQBGEHkiDsQBL8xDW5ZcuWldZ37NjR2LIfeuih0jqX1urFlh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA6e3Jr1px5L9EPu+aaayrNf/fu3V1rDz74YKV54+Nhyw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfS8lXStC+NW0gO3bl35MHxbt24trS9cuLC0vnfv3tJ62e2ge92GGv3pditptuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kAS/Zz8HlN37vcn7vkvSkSNHSutcSx8ePbfsth+zPW374Kxp99ietL2/
+Lux2TYBVDWf3fitkua6nckvImJ58fenetsCULeeYY+IZyW9PYBeADSoygm6W20fKHbzF3V7k+0x2xO2JyosC0BF/Yb9l5K+IGm5pBOS7u/2xogYj4gVEbGiz2UBqEFfYY+IqYh4PyJmJP1K0nX1tgWgbn2F3faSWS+/Kelgt/cCGA49r7PbflzSSkmLbR+X9GNJK20vlxSSjkq6pcEe0cOmTZu61mZmZhpd9ubNmxudP+rTM+wRsWGOyY820AuABvF1WSAJwg4kQdiBJAg7kARhB5LgJ65ngeXLl5fWV69e3diyd+7cWVo/fPhwY8tGvdiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASDNl8Fpieni6tL1rU9a5gPT333HOl9RtuuKG0fvLkyb6XjWYwZDOQHGEHkiDsQBKEHUiCsANJEHYgCcIOJMHv2c8Cl156aWm9yu2iH3nkkdI619HPHWzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJrrMPgS1btpTWFyxo7v/kvXv3NjZvDJee/4psL7W9x/bLtl+y/YNi+iW2n7L9WvHY/x0UADRuPpuM9yT9MCKulvQVSd+3fbWkOyXtjoirJO0uXgMYUj3DHhEnIuKF4vm7kl6RdIWktZK2FW/bJmldU00CqO5jHbPbXibpWkn/lDQSESeK0puSRrp8ZkzSWP8tAqjDvM/82L5A0g5Jt0fEf2fXonPXyjlvJhkR4xGxIiJWVOoUQCXzCrvtT6oT9N9ExBPF5CnbS4r6Eknlt0AF0Kqeu/G2LelRSa9ExAOzSrskbZS0uXgsH9s3sV5DLq9ataq03usnrKdOnepae/jhh0s/OzU1VVrHuWM+x+xflfRtSS/a3l9Mu0udkP/e9nclHZO0vpkWAdShZ9gj4h+S5rzpvKTr620HQFP4uiyQBGEHkiDsQBKEHUiCsANJ8BPXAbj44otL65dddlml+U9OTnat3XHHHZXmjXMHW3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1Igt+zD8ChQ4dK672GTR4dHa2zHSTFlh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHknBElL/BXippu6QRSSFpPCIetH2PpO9J+k/x1rsi4k895lW+MACVRcScoy7PJ+xLJC2JiBdsXyjpeUnr1BmP/WRE/Hy+TRB2oHndwj6f8dlPSDpRPH/X9iuSrqi3PQBN+1jH7LaXSbpW0j+LSbfaPmD7MduLunxmzPaE7YlKnQKopOdu/AdvtC+Q9Iykn0bEE7ZHJL2lznH8T9TZ1f9Oj3mwGw80rO9jdkmy/UlJf5T0l4h4YI76Mkl/jIgv9pgPYQca1i3sPXfjbVvSo5JemR304sTdad+UdLBqkwCaM5+z8aOS/i7pRUkzxeS7JG2QtFyd3fijkm4pTuaVzYstO9CwSrvxdSHsQPP63o0HcG4g7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJDHoIZvfknRs1uvFxbRhNKy9DWtfEr31q87eruxWGOjv2T+ycHsiIla01kCJYe1tWPuS6K1fg+qN3XggCcIOJNF22MdbXn6ZYe1tWPuS6K1fA+mt1WN2AIPT9pYdwIAQdiCJVsJue43tw7Zft31nGz10Y/uo7Rdt7297fLpiDL1p2wdnTbvE9lO2Xyse5xxjr6Xe7rE9Way7/bZvbKm3pbb32H7Z9ku2f1BMb3XdlfQ1kPU28GN22+dJelXSNyQdl7RP0oaIeHmgjXRh+6ikFRHR+hcwbH9N0klJ208PrWX7Z5LejojNxX+UiyJi05D0do8+5jDeDfXWbZjxm9Xiuqtz+PN+tLFlv07S6xFxJCJOSfqdpLUt9DH0IuJZSW+fMXmtpG3F823q/GMZuC69DYWIOBERLxTP35V0epjxVtddSV8D0UbYr5D0xqzXxzVc472HpL/a
ft72WNvNzGFk1jBbb0oaabOZOfQcxnuQzhhmfGjWXT/Dn1fFCbqPGo2IL0u6QdL3i93VoRSdY7Bhunb6S0lfUGcMwBOS7m+zmWKY8R2Sbo+I/86utbnu5uhrIOutjbBPSlo66/VnimlDISImi8dpSU+qc9gxTKZOj6BbPE633M8HImIqIt6PiBlJv1KL664YZnyHpN9ExBPF5NbX3Vx9DWq9tRH2fZKusv0525+S9C1Ju1ro4yNsLyxOnMj2QkmrNXxDUe+StLF4vlHSzhZ7+ZBhGca72zDjanndtT78eUQM/E/Sjeqckf+3pB+10UOXvj4v6V/F30tt9ybpcXV26/6nzrmN70q6VNJuSa9J+pukS4aot1+rM7T3AXWCtaSl3kbV2UU/IGl/8Xdj2+uupK+BrDe+LgskwQk6IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUji/5/q50l6GREBAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(images[3], cmap='gray', interpolation='none')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test Model Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_test_accuracy(model, X, y):\n",
    "    \"\"\"Return the fraction of samples in ``X`` that ``model`` classifies correctly.\n",
    "\n",
    "    Args:\n",
    "        model: callable mapping ``X`` to class scores (one row per sample);\n",
    "            the result must expose ``get_plain_text()``.\n",
    "        X: inputs passed to ``model``.\n",
    "        y: one-hot labels (one row per sample) exposing ``get_plain_text()``.\n",
    "\n",
    "    Returns:\n",
    "        Accuracy as a Python float in [0, 1].\n",
    "    \"\"\"\n",
    "    scores = model(X).get_plain_text()\n",
    "    # Softmax is monotonic within a row, so the arg-max of the raw scores is\n",
    "    # already the predicted class. Normalising over dim 0 (the batch axis)\n",
    "    # would rescale each column differently and can change the winner.\n",
    "    predicted = scores.argmax(1)\n",
    "    labels = y.get_plain_text().argmax(1)\n",
    "    correct = (predicted == labels).sum().float()\n",
    "    return float(correct / labels.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8799999952316284"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the separate test samples rather than the batch the model was trained on.\n",
    "avg_test_accuracy(model, images_test_enc, labels_test_enc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a CNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Based on the PyTorch MNIST example: https://github.com/pytorch/examples/blob/master/mnist/main.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(crypten.nn.Module):\n",
    "    \"\"\"Convolutional classifier for 28x28 single-channel images (10 classes).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = crypten.nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = crypten.nn.Conv2d(32, 64, 3, 1)\n",
    "        self.dropout1 = crypten.nn.Dropout2d(0.25)\n",
    "        self.dropout2 = crypten.nn.Dropout2d(0.5)\n",
    "        # 28 -> 26 -> 24 after two 3x3 convs, 12 after 2x2 pooling:\n",
    "        # 64 channels * 12 * 12 = 9216 features\n",
    "        self.fc1 = crypten.nn.Linear(9216, 128)\n",
    "        self.fc2 = crypten.nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a batch of (N, 28, 28) images to (N, 10) class scores.\"\"\"\n",
    "        # add the channel dimension expected by the conv layers\n",
    "        x = x.unsqueeze(1)\n",
    "        x = self.conv1(x)\n",
    "        x = x.relu()\n",
    "        x = self.conv2(x)\n",
    "        x = x.relu()\n",
    "        x = x.max_pool2d(2)\n",
    "        x = self.dropout1(x)\n",
    "        # collapse (N, 64, 12, 12) feature maps to (N, 9216) before the dense\n",
    "        # layers, mirroring torch.flatten(x, 1) in the PyTorch reference model\n",
    "        x = x.flatten(start_dim=1)\n",
    "        x = self.fc1(x)\n",
    "        x = x.relu()\n",
    "        x = self.dropout2(x)\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CNN().encrypt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 28, 28])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<crypten.autograd_cryptensor.AutogradCrypTensor object at 0x131944990>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = images_enc[0].unsqueeze(0)\n",
    "print(x.shape)\n",
    "model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0 loss: 2.3135986328125\n",
      "epoch 1 loss: 2.2531890869140625\n",
      "epoch 2 loss: 2.1912384033203125\n"
     ]
    }
   ],
   "source": [
    "model = train_model(model, images_enc[:10, ], labels_enc[:10,], epochs=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction = model(images_enc[3].unsqueeze(0)).argmax()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Importing PyTorch Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class PyTorchModel(nn.Module):\n",
    "    \"\"\"Plain PyTorch CNN for 28x28 images that returns per-class log-probabilities.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)\n",
    "        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)\n",
    "        self.dropout1 = nn.Dropout2d(p=0.25)\n",
    "        self.dropout2 = nn.Dropout2d(p=0.5)\n",
    "        self.fc1 = nn.Linear(in_features=9216, out_features=128)\n",
    "        self.fc2 = nn.Linear(in_features=128, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run a batch of (N, 28, 28) images through the network.\"\"\"\n",
    "        # insert the channel axis: (N, H, W) -> (N, 1, H, W)\n",
    "        features = F.relu(self.conv1(x.unsqueeze(1)))\n",
    "        features = F.relu(self.conv2(features))\n",
    "        features = self.dropout1(F.max_pool2d(features, 2))\n",
    "        # flatten everything except the batch axis for the dense layers\n",
    "        hidden = F.relu(self.fc1(torch.flatten(features, 1)))\n",
    "        logits = self.fc2(self.dropout2(hidden))\n",
    "        return F.log_softmax(logits, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_model = PyTorchModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<crypten.nn.module.Graph at 0x128d876d0>"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dummy_input = torch.empty(images.shape)\n",
    "crypten_model = crypten.nn.from_pytorch(pytorch_model, dummy_input)\n",
    "crypten_model.encrypt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<crypten.autograd_cryptensor.AutogradCrypTensor object at 0x131b7f410>\n"
     ]
    }
   ],
   "source": [
    "prediction = crypten_model(images_enc[3].unsqueeze(0))\n",
    "print(prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-2.2325, -2.2288, -2.2988, -2.3620, -2.2741, -2.2421, -2.4341, -2.2122,\n",
      "         -2.2984, -2.3620]])\n"
     ]
    }
   ],
   "source": [
    "print(prediction.get_plain_text())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(7)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prediction.get_plain_text().argmax()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0 loss: 2.3262786865234375\n",
      "epoch 1 loss: 2.287811279296875\n",
      "epoch 2 loss: 2.182342529296875\n"
     ]
    }
   ],
   "source": [
    "crypten_model = train_model(crypten_model, images_enc[:10, ], labels_enc[:10,], epochs=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "crypten-pycon",
   "language": "python",
   "name": "crypten-pycon"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
